syntax = "proto3";

// Tree-ensemble schema for the tensorflow.boosted_trees package: nodes,
// leaves, split rules, trees, and per-tree / training metadata.
package tensorflow.boosted_trees;
// Enables arena allocation for the generated C++ classes.
option cc_enable_arenas = true;
// Java generation: wrapper class name, one .java file per top-level type
// (java_multiple_files), placed in the org.tensorflow.framework package.
option java_outer_classname = "BoostedTreesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Node describes a node in a tree.
// A node carries one payload (a leaf or one kind of split) plus optional
// metadata.
message Node {
  // The payload of this node. As a oneof, at most one member is set; setting
  // one clears the others.
  oneof node {
    // Leaf payload (see Leaf).
    Leaf leaf = 1;
    // Split with an int32 threshold (see BucketizedSplit).
    BucketizedSplit bucketized_split = 2;
    // Split testing a feature against a single value (see CategoricalSplit).
    CategoricalSplit categorical_split = 3;
    // Split with a float threshold (see DenseSplit).
    DenseSplit dense_split = 4;
  }
  // Metadata for this node (see NodeMetadata). Declared outside the oneof,
  // so it can accompany any payload kind.
  NodeMetadata metadata = 777;
}

// NodeMetadata encodes metadata associated with each node in a tree.
// It is carried by Node.metadata.
message NodeMetadata {
  // The gain associated with this node.
  // Plain proto3 scalar: a value of 0 is indistinguishable from "not set".
  float gain = 1;

  // The original leaf node before this node was split.
  // Message-typed, so "unset" can be told apart from an empty Leaf.
  Leaf original_leaf = 2;
}

// Leaves can either hold dense or sparse information.
// A scalar field is also available, independent of the dense/sparse choice.
message Leaf {
  // Dense or sparse payload; at most one of the two members is set.
  oneof leaf {
    // See third_party/tensorflow/contrib/decision_trees/
    // proto/generic_tree_model.proto
    // for a description of how vector and sparse_vector might be used.
    Vector vector = 1;
    // Sparse counterpart of `vector` (see SparseVector).
    SparseVector sparse_vector = 2;
  }
  // Scalar value of the leaf. Declared outside the oneof, so setting it does
  // not clear vector / sparse_vector.
  // NOTE(review): how readers combine scalar with the oneof payload is not
  // stated in this file -- confirm against the code that consumes Leaf.
  float scalar = 3;
}

// Vector is a dense list of float values (used by Leaf.vector).
message Vector {
  // The values, in list order.
  repeated float value = 1;
}

// SparseVector holds a list of indices and a list of float values (used by
// Leaf.sparse_vector).
// NOTE(review): presumably index[i] is the position of value[i], so both
// lists should have the same length -- the schema does not enforce this;
// confirm against the code that reads it.
message SparseVector {
  // Indices of the stored entries.
  repeated int32 index = 1;
  // Values of the stored entries.
  repeated float value = 2;
}

// BucketizedSplit is a split node whose rule compares a feature against an
// int32 threshold (used by Node.bucketized_split).
message BucketizedSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  // NOTE(review): threshold is int32 although the comment above speaks of a
  // float feature column; presumably it is a bucket index of the bucketized
  // feature -- confirm.
  int32 feature_id = 1;
  int32 threshold = 2;
  // If feature column is multivalent, this holds the index of the dimension
  // for the split. Defaults to 0.
  int32 dimension_id = 5;
  // Which way the split falls back to; see default_direction below.
  enum DefaultDirection {
    // Left is the default direction.
    // As the zero value, this is also what an unset default_direction reads
    // as in proto3.
    DEFAULT_LEFT = 0;
    // Right is the default direction.
    DEFAULT_RIGHT = 1;
  }
  // default direction for missing values.
  DefaultDirection default_direction = 6;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  // NOTE(review): which child is taken when the rule holds is not stated in
  // this file -- presumably left; confirm.
  int32 left_id = 3;
  int32 right_id = 4;
}

// CategoricalSplit is a split node whose rule tests a feature for equality
// with a single int32 value (used by Node.categorical_split).
message CategoricalSplit {
  // Categorical feature column and split describing the rule feature value ==
  // value.
  int32 feature_id = 1;
  int32 value = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  // NOTE(review): which child is taken when the rule holds is not stated in
  // this file -- presumably left; confirm.
  int32 left_id = 3;
  int32 right_id = 4;
}

// DenseSplit is a split node whose rule compares a feature against a float
// threshold (used by Node.dense_split).
//
// TODO(nponomareva): move out of boosted_trees and rename to trees.proto
// (this TODO names a .proto file, so it appears to be file-level rather than
// specific to DenseSplit).
message DenseSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_id = 1;
  float threshold = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  // NOTE(review): which child is taken when the rule holds is not stated in
  // this file -- presumably left; confirm.
  int32 left_id = 3;
  int32 right_id = 4;
}

// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
// Note that each node id is implicitly its index in the list of nodes.
message Tree {
  // The nodes of the tree; the entry at index i is the node with id i, and
  // index 0 is the root.
  // NOTE(review): presumably the left_id / right_id fields of the split
  // messages index into this list -- confirm.
  repeated Node nodes = 1;
}

// TreeMetadata holds per-tree bookkeeping: growth progress, finalization
// state, and the node-id remapping produced by post-pruning.
// NOTE(review): field number 1 is not used in this message. If it belonged to
// a removed field, adding `reserved 1;` would prevent accidental reuse --
// confirm against the file history.
message TreeMetadata {
  // Number of layers grown for this tree.
  int32 num_layers_grown = 2;

  // Whether the tree is finalized in that no more layers can be grown.
  bool is_finalized = 3;

  // If tree was finalized and post pruning happened, it is possible that cache
  // still refers to some nodes that were deleted or that the node ids changed
  // (e.g. node id 5 became node id 2 due to pruning of the other branch).
  // The mapping below allows us to understand where the old ids now map to and
  // how the values should be adjusted due to post-pruning.
  // The size of the list should be equal to the number of nodes in the tree
  // before post-pruning happened.
  // If the node was pruned, it will have new_node_id equal to the id of a node
  // that this node was collapsed into. For a node that didn't get pruned, it is
  // possible that its id still changed, so new_node_id will have the
  // corresponding id in the pruned tree.
  // If post-pruning didn't happen, or it did and it had no effect (e.g. no
  // nodes got pruned), this list will be empty.
  // NOTE(review): presumably entry i describes the node whose pre-pruning id
  // was i -- confirm.
  repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;

  // One entry of the post-pruning mapping described above.
  message PostPruneNodeUpdate {
    // Id in the pruned tree that the old node now maps to.
    int32 new_node_id = 1;
    // NOTE(review): presumably the adjustment to apply to values cached for
    // the old node ("how the values should be adjusted") -- confirm sign and
    // usage with the code that applies it.
    float logit_change = 2;
  }
}

// GrowingMetadata records ensemble-wide growth counters and the node-id range
// of the most recently grown layer (carried by TreeEnsemble.growing_metadata).
message GrowingMetadata {
  // Number of trees that we have attempted to build. After pruning, these
  // trees might have been removed.
  int64 num_trees_attempted = 1;
  // Number of layers that we have attempted to build. After pruning, these
  // layers might have been removed.
  int64 num_layers_attempted = 2;
  // The start (inclusive) and end (exclusive) ids of the nodes in the latest
  // layer of the latest tree.
  // The range is half-open: [last_layer_node_start, last_layer_node_end).
  int32 last_layer_node_start = 3;
  int32 last_layer_node_end = 4;
}

// TreeEnsemble describes an ensemble of decision trees.
message TreeEnsemble {
  // The trees of the ensemble.
  repeated Tree trees = 1;
  // Weights of the trees.
  // NOTE(review): presumably tree_weights[i] and tree_metadata[i] belong to
  // trees[i], so all three lists should have equal length -- the schema does
  // not enforce this; confirm.
  repeated float tree_weights = 2;

  // Per-tree metadata (see TreeMetadata).
  repeated TreeMetadata tree_metadata = 3;
  // Metadata that is used during the training.
  GrowingMetadata growing_metadata = 4;
}

// DebugOutput contains outputs useful for debugging/model interpretation, at
// the individual example-level. Debug outputs that are available to the user
// are: 1) Directional feature contributions (DFCs) 2) Node IDs for ensemble
// prediction path 3) Leaf node IDs.
// Only the inputs for 1) are represented by fields so far; see the TODO at
// the end of the message for 2) and 3).
message DebugOutput {
  // Return the logits and associated feature splits across prediction paths for
  // each tree, for every example, at predict time. We will use these values to
  // compute DFCs in Python, by subtracting each child prediction from its
  // parent prediction and associating this change with its respective feature
  // id.
  // NOTE(review): how entries of feature_ids line up with entries of
  // logits_path (same length, or one more logit than splits) is not stated
  // here -- confirm against the producer.
  repeated int32 feature_ids = 1;
  repeated float logits_path = 2;

  // TODO(crawles): return 2) Node IDs for ensemble prediction path 3) Leaf node
  // IDs.
}
